Algorithm 7 RBONN training.
Input: a minibatch of inputs and their labels, real-valued weights $w$, recurrent model weights $U$, scaling factor matrix $A$, learning rates $\eta_1$, $\eta_2$, and $\eta_3$.
Output: updated real-valued weights $w^{t+1}$, updated scaling factor matrix $A^{t+1}$, and updated recurrent model weights $U^{t+1}$.
1: while Forward propagation do
2:   $b_w^t \leftarrow \mathrm{sign}(w^t)$.
3:   $b_{a_{in}}^t \leftarrow \mathrm{sign}(a_{in}^t)$.
4:   Feature calculation using Eq. 6.36.
5:   Loss calculation using Eq. 6.68.
6: end while
7: while Backward propagation do
8:   Compute $\frac{\partial L}{\partial A^t}$, $\frac{\partial L}{\partial w^t}$, and $\frac{\partial L}{\partial U^t}$ using Eqs. 6.70, 6.72, and 3.136.
9:   Update $A^{t+1}$, $w^{t+1}$, and $U^{t+1}$ according to Eqs. 6.69, 6.44, and 6.50, respectively.
10: end while
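The forward and backward passes above can be condensed into a single training step. Below is a minimal PyTorch-style sketch, not the reference implementation: the name `rbonn_step` is ours, `loss_fn` stands in for the feature and loss computations of Eqs. 6.36 and 6.68, the DReLU mask is a hypothetical placeholder for the term in Eq. 3.136, and $U$ is assumed to share $w$'s shape.

```python
import torch

def rbonn_step(w, A, U, loss_fn, eta1, eta2, eta3):
    """One training iteration roughly following Algorithm 7.

    loss_fn(b_w, A) stands in for the feature/loss computation
    (Eqs. 6.36 and 6.68) and must be differentiable in w and A.
    """
    # Forward propagation: binarize the real-valued weights.
    # sign(w) gives the forward value; (w - w.detach()) is zero in the
    # forward pass but routes the gradient straight through to w.
    b_w = torch.sign(w) + (w - w.detach())
    loss = loss_fn(b_w, A)

    # Backward propagation: dL/dA and dL/dw (Eqs. 6.70 and 6.72).
    grad_A, grad_w = torch.autograd.grad(loss, [A, w])

    with torch.no_grad():
        # dL/dU per Eq. 3.136; the (w > 0) mask is only a placeholder
        # for DReLU(w^{t-1}, A^t), whose exact form is defined elsewhere.
        grad_U = grad_w * (w > 0).to(w.dtype)
        # Parameter updates (Eqs. 6.69, 6.44, and 3.135).
        A -= eta1 * grad_A
        w -= eta2 * grad_w
        U.copy_((U - eta3 * grad_U).abs())   # Eq. 3.135 keeps U non-negative
    return loss.item()

# Toy usage with a made-up quadratic loss:
w = torch.randn(4, 8, requires_grad=True)
A = torch.rand(4, requires_grad=True)
U = torch.rand(4, 8)
loss = rbonn_step(w, A, U,
                  lambda b_w, A: ((A.unsqueeze(1) * b_w).sum() - 1.0) ** 2,
                  1e-2, 1e-2, 1e-2)
```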
where $w' = \mathrm{diag}(\|w_1\|_1, \cdots, \|w_{C_{out}}\|_1)$. We judge that asynchronous convergence occurs in the optimization when $(\neg D(w'_i)) \wedge D(A_i) = 1$, where the density function is defined as
$$D(x_i) = \begin{cases} 1 & \text{if } \mathrm{ranking}(\sigma(x)_i) > T, \\ 0 & \text{otherwise}, \end{cases} \qquad (3.134)$$
where $T = \mathrm{int}(C_{out} \times \tau)$ and $\tau$ is the hyperparameter that denotes the threshold. $\sigma(x)_i$ denotes the $i$-th eigenvalue of the diagonal matrix $x$, and $x_i$ denotes the $i$-th row of the matrix $x$.
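As an illustration, a small PyTorch sketch of the density function and the asynchronous-convergence test follows. The helper names `density` and `async_converged` are ours, and we assume $\mathrm{ranking}(\cdot)$ means the ascending 0-based rank of $\sigma(x)_i$ among the diagonal entries, which is our reading of Eq. 3.134 rather than a detail stated here.

```python
import torch

def density(diag_vals, tau):
    """D(x_i) of Eq. 3.134 for a diagonal matrix given by its diagonal.

    diag_vals holds the eigenvalues sigma(x)_i; ranking() is assumed to
    be the ascending 0-based rank of each entry among the diagonal.
    Returns a 0/1 tensor with D(x_i) = 1 iff ranking(sigma(x)_i) > T.
    """
    c_out = diag_vals.numel()
    T = int(c_out * tau)                    # T = int(C_out * tau)
    ranks = diag_vals.argsort().argsort()   # ascending ranks, 0-based
    return (ranks > T).to(torch.int64)

def async_converged(w, A_diag, tau):
    """Flag (not D(w'_i)) and D(A_i) == 1 per output channel,
    where w' = diag(||w_1||_1, ..., ||w_Cout||_1)."""
    w_prime = w.flatten(1).abs().sum(dim=1)   # ||w_i||_1 per channel
    return (1 - density(w_prime, tau)) * density(A_diag, tau)
```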
Finally, we define the optimization of $U$ as
$$U^{t+1} = \left| U^t - \eta_3 \frac{\partial L}{\partial U^t} \right|, \qquad (3.135)$$
$$\frac{\partial L}{\partial U^t} = \frac{\partial L_S}{\partial w^t} \circ \mathrm{DReLU}(w^{t-1}, A^t), \qquad (3.136)$$
where $\eta_3$ is the learning rate of $U$. We elaborate on the RBONN training process in Algorithm 7.
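For concreteness, a minimal sketch of this update is shown below. The name `update_U` is ours; it assumes that $U$, $\partial L_S / \partial w^t$, and the DReLU mask all share one shape, and it takes the mask as a precomputed input since its exact form is defined elsewhere.

```python
import torch

def update_U(U, grad_ws, drelu_mask, eta3):
    """U update of Eqs. 3.135-3.136 (a sketch under shape assumptions).

    grad_ws:    dL_S/dw^t, gradient of the loss w.r.t. real weights
    drelu_mask: DReLU(w^{t-1}, A^t), taken here as a precomputed tensor
    """
    grad_U = grad_ws * drelu_mask      # Eq. 3.136 (Hadamard product)
    return (U - eta3 * grad_U).abs()   # Eq. 3.135: |U^t - eta3 * dL/dU^t|
```

The absolute value keeps $U$ non-negative after each step, matching Eq. 3.135.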
3.8.3 Discussion
In this section, we first review related methods for "gradient approximation" in BNNs, then discuss how RBONN differs from these methods and analyze the effectiveness of the proposed RBONN.
In particular, BNN [99] directly utilizes the Straight-Through Estimator (STE) in the training stage to calculate the gradients of weights and activations as
$$\frac{\partial b_{w_{i,j}}}{\partial w_{i,j}} = 1_{|w_{i,j}|<1}, \qquad \frac{\partial b_{a_{i,j}}}{\partial a_{i,j}} = 1_{|a_{i,j}|<1}, \qquad (3.137)$$
which suffers from an obvious gradient mismatch between the binarization function and its surrogate gradient.
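A minimal PyTorch realization of this estimator, written as a custom autograd function, might look as follows; the class name `SignSTE` is ours.

```python
import torch

class SignSTE(torch.autograd.Function):
    """sign() with the straight-through estimator of Eq. 3.137."""

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return torch.sign(x)

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        # Identity gradient inside (-1, 1), zero outside (Eq. 3.137).
        return grad_output * (x.abs() < 1).to(grad_output.dtype)

# Usage: b_w = SignSTE.apply(w)
```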
Intuitively, Bi-Real Net [159] designs an approximate binarization function that helps alleviate the gradient mismatch in backward propagation:
$$\frac{\partial b_{a_{i,j}}}{\partial a_{i,j}} = \begin{cases} 2 + 2a_{i,j}, & -1 \le a_{i,j} < 0, \\ 2 - 2a_{i,j}, & 0 \le a_{i,j} < 1, \\ 0, & \text{otherwise}, \end{cases} \qquad (3.138)$$
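A corresponding sketch of this piecewise gradient as a custom autograd function is given below, with $\mathrm{sign}(\cdot)$ kept in the forward pass; the class name `SignApprox` is ours.

```python
import torch

class SignApprox(torch.autograd.Function):
    """sign() with Bi-Real Net's piecewise polynomial gradient (Eq. 3.138)."""

    @staticmethod
    def forward(ctx, a):
        ctx.save_for_backward(a)
        return torch.sign(a)

    @staticmethod
    def backward(ctx, grad_output):
        (a,) = ctx.saved_tensors
        grad = torch.zeros_like(a)           # zero gradient outside [-1, 1)
        neg = (a >= -1) & (a < 0)
        pos = (a >= 0) & (a < 1)
        grad[neg] = 2 + 2 * a[neg]           # 2 + 2a on [-1, 0)
        grad[pos] = 2 - 2 * a[pos]           # 2 - 2a on [0, 1)
        return grad_output * grad

# Usage: b_a = SignApprox.apply(a)
```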